library(broom)
library(dplyr)
library(ebal)
library(fixest)
library(ggplot2)
library(ggrepel)
library(gridExtra)
library(here)
library(kableExtra)
library(knitr)
library(readr)
library(tidyr)
library(wru) # make sure to use 2020 census data, options("wru_data_wd" = TRUE)

# Project-relative input and output directories, resolved with here::here().
datadir <- here::here('data')
outputdir <- here::here('output')

# Round `num` to `digits` decimal places and render the result as text that
# always shows at least `digits` decimals (e.g. 2 becomes "2.0" for digits = 1).
# Vector input is formatted to a common width by format().
round_num <- function(num, digits) {
  rounded <- round(num, digits)
  format(rounded, nsmall = digits)
}

# Build a table of means and standard deviations by treatment arm.
#
# Arguments:
#   data      Data frame with a `treat` column (rows with treat == 0 form the
#             control arm, rows with treat == 1 the treated arm) and one
#             column per key/period pair, named paste0(key, period).
#   format    Not used by the function; kept so existing calls keep working.
#   keys      Character vector of column-name prefixes, one table row each.
#   period    Numeric suffixes handed to num_range(); all matching columns
#             are pooled into one vector before summarising.
#   row_names Row labels, one per key.
#
# Returns a character matrix (formatted to one decimal by round_num()) whose
# columns are Mean / St Dev for the control arm followed by Mean / St Dev for
# the treated arm.
summary_table <- function(data, format="latex", keys, period, row_names){
  n_keys <- length(keys)
  stat_cols <- c("Mean", "St Dev")
  control_stats <- matrix(NA, n_keys, 2, dimnames = list(NULL, stat_cols))
  treated_stats <- matrix(NA, n_keys, 2, dimnames = list(NULL, stat_cols))

  # Split the arms once instead of re-filtering inside the loop.
  control_rows <- filter(data, treat == 0)
  treated_rows <- filter(data, treat == 1)

  # Pool every selected column of one arm into a single vector and return
  # its mean and standard deviation, ignoring missing values.
  pooled_stats <- function(rows, key) {
    values <- rows |>
      select(num_range(key, period)) |>
      unlist() |>
      unname()
    c(mean(values, na.rm = TRUE), sd(values, na.rm = TRUE))
  }

  # seq_len() keeps the loop from running when `keys` is empty.
  for (i in seq_len(n_keys)) {
    control_stats[i, ] <- pooled_stats(control_rows, keys[i])
    treated_stats[i, ] <- pooled_stats(treated_rows, keys[i])
  }
  rownames(control_stats) <- rownames(treated_stats) <- row_names

  round_num(cbind(control_stats, treated_stats), 1)
}

# Author-level dataset read from the data directory; the author-based tables
# and figures below are all derived from it.
authors <- file.path(datadir, 'authors_for_analysis.csv') |>
  read_csv()

#############################
# Figure 1
#############################

# Figure 1: each partner country's yearly series divided by the US column,
# plotted against the `index` column (labelled 'Year' on the axis).
collab_trend_us <- read_csv(file.path(datadir, 'yearly_pb_trend_US.csv'))
# Columns are renamed by position, so the CSV is assumed to hold exactly these
# eleven columns in this order -- TODO confirm against the data file.
colnames(collab_trend_us) <- c("index", "UK", "China", "Germany", "Japan", "France", "Canada", "Australia", "Switzerland", "Netherlands", "US")

collab_trend_us <- collab_trend_us |>
  mutate(across(UK:Netherlands, ~ .x/US)) |>
  select(!US) |>
  pivot_longer(cols = UK:Netherlands, names_to = "Country", values_to = "val") |>
  # Explicit factor levels fix the legend (and shape) order.
  mutate(Country = factor(Country , levels=c("China", "UK", "Germany", "Japan", "France", "Australia", "Canada", "Switzerland", "Netherlands")))

p1 <- ggplot(data=collab_trend_us, aes(x=index, y=val, fill=Country, color=Country)) +
  geom_point(aes(shape=Country), size=2.8) +
  geom_line(linewidth=0.8) +
  theme_bw() +
  labs(x='Year', y='Share of U.S. Publications') +
  scale_shape_manual(values = 1:9) +
  geom_vline(xintercept=2018, linetype='longdash') +
  scale_x_continuous(breaks=seq(2010, 2021, 2)) +
  theme(text = element_text(size = 18))

# No leading slash on the file name: file.path() already inserts the
# separator, and a leading slash would produce ".../output//yearly_trend_us.png".
ggsave(file.path(outputdir, 'yearly_trend_us.png'), p1, width=9, height=7)

rm(list=c("collab_trend_us", "p1"))


#############################
# Table 1: Descriptive Stats
#############################

# Table 1: mean and standard deviation of each outcome by treatment arm,
# separately for the 2015-2018 and 2019-2021 column ranges.
keys <- c("cite_total_", "cite_pubmed_", "cite_nih_", "pub_pubmed_")

vnames <- c("Total Citations", "PubMed Citations", "NIH Citations", "PubMed Publications")

tab1_pre <- summary_table(data=authors, keys=keys, period=2015:2018, row_names=vnames)
tab1_post <- summary_table(data=authors, keys=keys, period=2019:2021, row_names=vnames)

# Stack the two periods; pack_rows() below labels each half of the table.
tab1 <- rbind(tab1_pre, tab1_post)
# TRUE, not T: T is an ordinary variable that can be reassigned.
kbl(tab1, booktabs = TRUE, format='latex') |>
  kable_styling(latex_options = c("scale_down"), position = "center") |>
  add_header_above(c(" " = 1, "Control Group" = 2, "Treatment Group" = 2)) |>
  pack_rows("Pre-treatment (2015-2018)", 1, length(vnames)) |>
  pack_rows("Post-treatment (2019-2021)", length(vnames)+1, length(vnames)*2)

# Sum of the `Asian` column within each treatment arm.
authors |>
  group_by(treat) |>
  summarize(sum(Asian))

rm(list=c("tab1", "tab1_pre", "tab1_post", "keys", "vnames"))


#############################
# Figure 2
#############################

# Figure 2, left panel: yearly gap (treated minus control) in the mean of
# log(1 + PubMed publications).
f2_pub <- authors |> # PubMed publications
  select(author_researcher_id, treat, num_range('pub_pubmed_', 2015:2021)) |>
  pivot_longer(
    cols = num_range("pub_pubmed_", 2015:2021),
    names_to = "year",
    names_prefix = 'pub_pubmed_',
    values_to = "Y") |>
  mutate(y_log = log(Y+1)) |>
  group_by(year) |>
  # Take the difference within each year instead of subtracting row blocks
  # 1:7 from 8:14, which silently misaligns if a year or arm is missing.
  summarise(
    Y = mean(y_log[which(treat == 1)]) - mean(y_log[which(treat == 0)])
  ) |>
  # Single series; the label keeps the group/colour mapping of the plot.
  mutate(treat = 'Treated')

p2 <- ggplot(data=f2_pub, aes(x=year, y=Y, group=treat, color=treat)) +
  geom_point() +
  geom_line() +
  # `year` is character, so the x axis is discrete: position 4 is the fourth
  # year, i.e. 2018.
  geom_vline(xintercept=4, linetype='longdash') +
  theme_bw() +
  labs(x='Year', y='Treated - Control') +
  ggtitle('Difference in PubMed Publications (log)') +
  theme(legend.position="none") +
  theme(text = element_text(size = 14))

# Figure 2, right panel: yearly gap (treated minus control) in the mean of
# log(1 + PubMed citations).
f2_cite <- authors |> # PubMed citations
  select(author_researcher_id, treat, num_range('cite_pubmed_', 2015:2021)) |>
  pivot_longer(
    cols = num_range("cite_pubmed_", 2015:2021),
    names_to = "year",
    names_prefix = 'cite_pubmed_',
    values_to = "Y") |>
  mutate(y_log = log(Y+1)) |>
  group_by(year) |>
  # Take the difference within each year instead of subtracting row blocks
  # 1:7 from 8:14, which silently misaligns if a year or arm is missing.
  summarise(
    Y = mean(y_log[which(treat == 1)]) - mean(y_log[which(treat == 0)])
  ) |>
  # Single series; the label keeps the group/colour mapping of the plot.
  mutate(treat = 'Treated')

p3 <- ggplot(data=f2_cite, aes(x=year, y=Y, group=treat, color=treat)) +
  geom_point() +
  geom_line() +
  # `year` is character, so the x axis is discrete: position 4 is the fourth
  # year, i.e. 2018.
  geom_vline(xintercept=4, linetype='longdash') +
  theme_bw() +
  labs(x='Year', y='Treated - Control') +
  ggtitle('Difference in PubMed Citations (log)') +
  theme(legend.position="none") +
  theme(text = element_text(size = 14))

# Put the two Figure 2 panels side by side, write the PNG, then drop the
# intermediate objects.
p3 <- arrangeGrob(grobs = list(p2, p3), nrow = 1)
ggsave(
  filename = file.path(outputdir, 'productivity_trend_log_diff.png'),
  plot = p3,
  width = 10,
  height = 5,
  dpi = 300
)

rm(f2_pub, f2_cite, p2, p3)


#############################
# Figure 3
#############################

# Figure 3, Panel A (left): event-study estimates for log(1 + PubMed citations).
# Reshape the cite_pubmed_2015..2021 columns into one row per author-year.
para_pb <- authors |>
  select(author_researcher_id, treat, num_range('cite_pubmed_', 2015:2021)) |>
  pivot_longer(
    cols = num_range('cite_pubmed_', 2015:2021),
    names_to = "year",
    names_prefix = 'cite_pubmed_',
    values_to = "Y_raw") |>
  mutate(
    # Assumes the long data holds seven consecutive rows (2015..2021) per
    # author, so recycling 2015:2021 rebuilds the calendar year.
    period = rep(2015:2021, n()/7),
    Y = log(Y_raw+1),
    post_treatment = as.numeric(period >= 2019),
    treatment = post_treatment * treat,
    author_id = as.character(author_researcher_id)
  )

# Year-by-treat interactions with 2018 as the omitted reference year, year and
# author fixed effects, standard errors clustered by author.
out1 <- feols(Y ~ i(period, treat, 2018) | period + author_id,
              data=para_pb, cluster=~author_id)
out1 <- tidy(out1)
# Assumes tidy() returns six rows (2015-2017, 2019-2021): shift the last three
# down by one so that row 4 becomes a placeholder for the reference year.
out1[5:7, ] <- out1[4:6,] # add 2018
out1$term <- c(2015:2021)
# Columns 2:5 of the tidy output are zeroed for the reference year.
out1[out1$term==2018,2:5] <- 0
out1$lower <- with(out1, estimate-1.96*std.error)
out1$upper <- with(out1, estimate+1.96*std.error)
# No interval is drawn for the reference year (columns 6:7 are lower/upper).
out1[out1$term==2018, 6:7] <- NA

# `linewidth` replaces `size` for line geoms (deprecated since ggplot2 3.4.0).
p4 <- ggplot(out1, aes(term, estimate)) +
  geom_errorbar(aes(ymin = lower, ymax = upper), width = 0.1, linewidth=0.5, color='black') +
  geom_point(aes(x=term, y=estimate), size=2, color='black') +
  theme_bw() +
  geom_hline(yintercept = 0) +
  geom_vline(xintercept = 2018, linetype='longdash', color='red') +
  labs(y='Estimates and 95% Conf. Int.', x='Year') +
  ggtitle('Effect on PubMed Citations') +
  theme(text = element_text(size = 14)) +
  ylim(-0.34, 0.1)


# Figure 3, Panel A (right): event-study estimates for log(1 + total citations).
# Reshape the cite_total_2015..2021 columns into one row per author-year.
# (A redundant first assignment of para_total, immediately overwritten by this
# pipeline, has been removed.)
para_total <- authors |>
  select(author_researcher_id, treat, num_range('cite_total_', 2015:2021)) |>
  pivot_longer(
    cols = num_range('cite_total_', 2015:2021),
    names_to = "year",
    names_prefix = 'cite_total_',
    values_to = "Y_raw") |>
  mutate(
    # Assumes the long data holds seven consecutive rows (2015..2021) per
    # author, so recycling 2015:2021 rebuilds the calendar year.
    period = rep(2015:2021, n()/7),
    Y = log(Y_raw+1),
    post_treatment = as.numeric(period >= 2019),
    treatment = post_treatment * treat,
    author_id = as.character(author_researcher_id)
  )

# Year-by-treat interactions with 2018 as the omitted reference year, year and
# author fixed effects, standard errors clustered by author.
out2 <- feols(Y ~ i(period, treat, 2018) | period + author_id,
              data=para_total, cluster=~author_id)
out2 <- tidy(out2)
# Assumes tidy() returns six rows (2015-2017, 2019-2021): shift the last three
# down by one so that row 4 becomes a placeholder for the reference year.
out2[5:7, ] <- out2[4:6,] # add 2018
out2$term <- c(2015:2021)
# Index with out2's own term column; the original indexed with out1$term,
# which only worked because both tables happen to share the same years.
out2[out2$term==2018,2:5] <- 0
out2$lower <- with(out2, estimate-1.96*std.error)
out2$upper <- with(out2, estimate+1.96*std.error)
# No interval is drawn for the reference year (columns 6:7 are lower/upper).
out2[out2$term==2018, 6:7] <- NA

# `linewidth` replaces `size` for line geoms (deprecated since ggplot2 3.4.0).
p5 <- ggplot(out2, aes(term, estimate)) +
  geom_errorbar(aes(ymin = lower, ymax = upper), width = 0.1, linewidth=0.5, color='black') +
  geom_point(aes(x=term, y=estimate), size=2, color='black') +
  theme_bw() +
  geom_hline(yintercept = 0) +
  geom_vline(xintercept = 2018, linetype='longdash', color='red') +
  labs(y='Estimates and 95% Conf. Int.', x='Year') +
  ggtitle('Effect on Total Citations') +
  theme(text = element_text(size = 14)) +
  ylim(-0.34, 0.1)

# ebal
# Figure 3, Panel B (left): the PubMed-citation event study re-estimated with
# entropy-balancing weights on the control authors.
para_eb_pb <- authors |>
  select(author_researcher_id, treat, pub_total_1014, cite_total_1014,
         pub_nih_1014, Asian, num_range('cite_pubmed_', 2015:2021))

# ebalance() is run on the four baseline covariates; its `w` element is
# assigned to the treat == 0 rows below, while treat == 1 rows keep weight 1.
w <- ebalance(X=select(para_eb_pb, c(pub_total_1014, cite_total_1014, pub_nih_1014, Asian)),
              Treatment=para_eb_pb$treat)$w
para_eb_pb$weights <- 1
para_eb_pb$weights[para_eb_pb$treat==0] <- w

para_eb_pb <- para_eb_pb |>
  pivot_longer(
    cols = num_range('cite_pubmed_', 2015:2021),
    names_to = "year",
    names_prefix = 'cite_pubmed_',
    values_to = "Y_raw") |>
  mutate(
    # Assumes the long data holds seven consecutive rows (2015..2021) per
    # author, so recycling 2015:2021 rebuilds the calendar year.
    period = rep(2015:2021, n()/7),
    Y = log(Y_raw+1),
    post_treatment = as.numeric(period >= 2019),
    treatment = post_treatment * treat,
    author_id = as.character(author_researcher_id)
  )

# Same specification as Panel A, now weighted by the `weights` column.
out3 <- feols(Y ~ i(period, treat, 2018) | period + author_id,
              data=para_eb_pb, weights=~weights, cluster=~author_id)
out3 <- tidy(out3)
# Assumes tidy() returns six rows (2015-2017, 2019-2021): shift the last three
# down by one so that row 4 becomes a placeholder for the reference year.
out3[5:7, ] <- out3[4:6,] # add 2018
out3$term <- c(2015:2021)
out3[out3$term==2018,2:5] <- 0
out3$lower <- with(out3, estimate-1.96*std.error)
out3$upper <- with(out3, estimate+1.96*std.error)
# No interval is drawn for the reference year (columns 6:7 are lower/upper).
out3[out3$term==2018, 6:7] <- NA

# `linewidth` replaces `size` for line geoms (deprecated since ggplot2 3.4.0).
p6 <- ggplot(out3, aes(term, estimate)) +
  geom_errorbar(aes(ymin = lower, ymax = upper), width = 0.1, linewidth=0.5, color='black') +
  geom_point(aes(x=term, y=estimate), size=2, color='black') +
  theme_bw() +
  geom_hline(yintercept = 0) +
  geom_vline(xintercept = 2018, linetype='longdash', color='red') +
  labs(y='Estimates and 95% Conf. Int.', x='Year') +
  ggtitle('Effect on PubMed Citations') +
  theme(text = element_text(size = 14)) +
  ylim(-0.34, 0.1)

# Figure 3, Panel B (right): the total-citation event study with
# entropy-balancing weights on the control authors.
para_eb_total <- authors |>
  select(author_researcher_id, treat, num_range('cite_total_', 2015:2021))
# Reuses the weight vector `w` computed for the PubMed-citation sample; both
# samples are taken from `authors` without filtering, so rows line up.
para_eb_total$weights <- 1
para_eb_total$weights[para_eb_total$treat==0] <- w

para_eb_total <- para_eb_total |>
  pivot_longer(
    cols = num_range('cite_total_', 2015:2021),
    names_to = "year",
    names_prefix = 'cite_total_',
    values_to = "Y_raw") |>
  mutate(
    # Assumes the long data holds seven consecutive rows (2015..2021) per
    # author, so recycling 2015:2021 rebuilds the calendar year.
    period = rep(2015:2021, n()/7),
    Y = log(Y_raw+1),
    post_treatment = as.numeric(period >= 2019),
    treatment = post_treatment * treat,
    author_id = as.character(author_researcher_id)
  )

# Same specification as Panel A, now weighted by the `weights` column.
out4 <- feols(Y ~ i(period, treat, 2018) | period + author_id,
              data=para_eb_total, weights=~weights, cluster=~author_id)
out4 <- tidy(out4)
# Assumes tidy() returns six rows (2015-2017, 2019-2021): shift the last three
# down by one so that row 4 becomes a placeholder for the reference year.
out4[5:7, ] <- out4[4:6,] # add 2018
out4$term <- c(2015:2021)
out4[out4$term==2018,2:5] <- 0
out4$lower <- with(out4, estimate-1.96*std.error)
out4$upper <- with(out4, estimate+1.96*std.error)
# No interval is drawn for the reference year (columns 6:7 are lower/upper).
out4[out4$term==2018, 6:7] <- NA

# `linewidth` replaces `size` for line geoms (deprecated since ggplot2 3.4.0).
p7 <- ggplot(out4, aes(term, estimate)) +
  geom_errorbar(aes(ymin = lower, ymax = upper), width = 0.1, linewidth=0.5, color='black') +
  geom_point(aes(x=term, y=estimate), size=2, color='black') +
  theme_bw() +
  geom_hline(yintercept = 0) +
  geom_vline(xintercept = 2018, linetype='longdash', color='red') +
  labs(y='Estimates and 95% Conf. Int.', x='Year') +
  ggtitle('Effect on Total Citations') +
  theme(text = element_text(size = 14)) +
  ylim(-0.34, 0.1)

# Stack the unweighted row (Panel A) above the entropy-balanced row (Panel B),
# write the PNG, then drop every Figure 3 intermediate.
p8 <- arrangeGrob(
  arrangeGrob(p4, p5, left="Panel A", ncol=2),
  arrangeGrob(p6, p7, left="Panel B", ncol=2),
  nrow=2
)
ggsave(
  filename = file.path(outputdir, 'citation_event.png'),
  plot = p8,
  width = 10,
  height = 9,
  dpi = 300
)

rm(p4, p5, p6, p7, p8,
   out1, out2, out3, out4,
   para_pb, para_total, para_eb_pb, para_eb_total,
   w)


#############################
# Table 2
#############################

# Table 2 data preparation: derive the non-PubMed outcomes, allocate the
# result matrix and attach entropy-balancing weights.
authors_main <- authors |>
  mutate(
    # Each unnamed expression subtracts one block of seven yearly columns from
    # another; the resulting columns take their names from the first across()
    # call, whose `.names` glue string expands to one name per year.
    # non-PubMed publications = total publications - PubMed publications
    across(num_range('pub_total_', 2015:2021), .names = "pub_non_pubmed_{2015:2021}") - 
      across(num_range('pub_pubmed_', 2015:2021)),
    # non-PubMed citations = total citations - PubMed citations
    across(num_range('cite_total_', 2015:2021), .names = "cite_non_pubmed_{2015:2021}") - 
      across(num_range('cite_pubmed_', 2015:2021))
    )

# Outcome column prefixes, in the order in which the loop below fills `res`.
outcome_keys <- c('pub_pubmed_', 'cite_pubmed_', 'pub_non_pubmed_', 'cite_non_pubmed_', 'pub_total_', 'cite_total_')
# 27 rows = 3 panels x 9 rows; 10 columns = 2 outcomes per panel x 5 models.
res <- matrix(NA, 27, 10)

# Entropy-balancing weights computed from the four baseline covariates.
w <- ebalance(X=select(authors_main, c(pub_total_1014, cite_total_1014, pub_nih_1014, Asian)),
              Treatment=authors_main$treat)$w

# Rows with treat == 1 keep weight 1; rows with treat == 0 receive `w`.
authors_main$weights <- 1
authors_main$weights[authors_main$treat==0] <- w

# Table 2 estimation: for each outcome fit five specifications and write the
# formatted results into a 9-row by 5-column slot of `res`.
#
# Slot layout (loop-invariant, so defined once outside the loop): outcomes
# alternate between the left (columns 1-5) and right (columns 6-10) half, and
# each pair of outcomes shares one 9-row panel.
row_switch <- c(0, 0, 1, 1, 2, 2)
col_switch <- c(0, 1, 0, 1, 0, 1)

for (i in seq_along(outcome_keys)){
  dat <- authors_main |>
    select(author_researcher_id, treat, pub_total_1014, cite_total_1014, pub_nih_1014, Asian,
           num_range(outcome_keys[i], 2015:2021), weights)

  row_n <- 1 + row_switch[i]*9
  col_n <- 1 + col_switch[i]*5

  panel <- dat |>
    pivot_longer(
      cols = num_range(outcome_keys[i], 2015:2021),
      names_to = "year",
      names_prefix = outcome_keys[i],
      values_to = "Y_raw") |>
    mutate(
      # Assumes the long data holds seven consecutive rows (2015..2021) per
      # author, so recycling 2015:2021 rebuilds the calendar year.
      period = rep(2015:2021, n()/7),
      Y = log(Y_raw+1),
      Y_asinh = asinh(Y_raw),
      post_treatment = as.numeric(period >= 2019),
      treatment = post_treatment * treat,
      author_id = as.character(author_researcher_id)
    )

  # Pre-period mean of the treated group on each outcome scale
  # (log, asinh, raw counts).
  avg_pre <- panel |> filter(treat==1 & post_treatment==0) |> summarise(avg = mean(Y)) |> unlist() |> round_num(3)
  avg_pre1 <- panel |> filter(treat==1 & post_treatment==0) |> summarise(avg = mean(Y_asinh)) |> unlist() |> round_num(3)
  avg_pre2 <- panel |> filter(treat==1 & post_treatment==0) |> summarise(avg = mean(Y_raw)) |> unlist() |> round_num(3)

  # Each column vector below holds: estimate, (s.e.), pre-treatment average,
  # R2, number of observations, then the four Y/blank flags matching the last
  # four row labels of the table.

  # (1) log outcome, year and author fixed effects.
  out1 <- feols(Y~treatment|period+author_id, data=panel, cluster=~author_id)
  res[row_n:(row_n+8), col_n] <- c(
    round_num(out1$coefficients, 3), paste0('(', round_num(out1$se, 3), ')'),
    avg_pre, round_num(r2(out1, type='ar2'), 3),
    out1$nobs, 'Y', 'Y', '', '')

  # (2) adds year-specific slopes on the four baseline covariates.
  out2 <- feols(Y~treatment|period+period[pub_total_1014]+period[cite_total_1014]+period[pub_nih_1014]+period[Asian]+author_id, data=panel, cluster=~author_id)
  res[row_n:(row_n+8), col_n+1] <- c(
    round_num(out2$coefficients, 3), paste0('(', round_num(out2$se, 3), ')'),
    avg_pre, round_num(r2(out2, type='ar2'), 3),
    out2$nobs, 'Y', 'Y', 'Y', '')

  # (3) specification (1) with the entropy-balancing weights.
  out3 <- feols(Y~treatment|period+author_id, data=panel, weights=~weights, cluster=~author_id)
  res[row_n:(row_n+8), col_n+2] <- c(
    round_num(out3$coefficients, 3), paste0('(', round_num(out3$se, 3), ')'),
    avg_pre, round_num(r2(out3, type='ar2'), 3),
    out3$nobs, 'Y', 'Y', '', 'Y')

  # (4) asinh outcome.
  # NOTE(review): this model includes the covariate-by-year terms and no
  # weights, yet its flags mark 'Entropy Balancing' and leave 'Baseline
  # Covariates*Year FE' blank. Either the flags or the specification look
  # like a copy-paste slip -- confirm which one is intended.
  out4 <- feols(Y_asinh~treatment|period+period[pub_total_1014]+period[cite_total_1014]+period[pub_nih_1014]+period[Asian]+author_id, data=panel, cluster=~author_id)
  res[row_n:(row_n+8), col_n+3] <- c(
    round_num(out4$coefficients, 3), paste0('(', round_num(out4$se, 3), ')'),
    avg_pre1, round_num(r2(out4, type='ar2'), 3),
    out4$nobs, 'Y', 'Y', '', 'Y')

  # (5) Poisson on raw counts, reporting the 'pr2' statistic.
  # NOTE(review): same flag/specification mismatch as column (4).
  out5 <- fepois(Y_raw~treatment|period+period[pub_total_1014]+period[cite_total_1014]+period[pub_nih_1014]+period[Asian]+author_id, data=panel, cluster=~author_id)
  res[row_n:(row_n+8), col_n+4] <- c(
    round_num(out5$coefficients, 3), paste0('(', round_num(out5$se, 3), ')'),
    avg_pre2, round_num(r2(out5, type='pr2'), 3),
    out5$nobs, 'Y', 'Y', '', 'Y')
}

# Table 2 output: label the rows, render the LaTeX table, then clean up.
# The nine labels repeat once per panel (A, B, C).
row_names <- rep(c('Ties to China × Post', '', 'Pre-treatment avg.', 'R2',
                   'No. of obs.', 'Scholar FE', 'Year FE', 'Baseline Covariates*Year FE',
                   'Entropy Balancing'), 3)

res <- cbind(row_names, res)
colnames(res) <- NULL

# `res` has one label column plus ten estimate columns: columns 1-5 hold the
# publication outcome of each panel and columns 6-10 the citation outcome.
# The previous headers spanned only seven columns, repeated
# "PubMed Publications" for both groups and numbered only (1)-(6).
kbl(res, booktabs = TRUE, format='latex') |>
  kable_styling(latex_options = c("scale_down"), position = "center") |>
  add_header_above(c("", "Publications" = 5, "Citations" = 5)) |>
  add_header_above(c("", paste0("(", 1:10, ")"))) |>
  pack_rows("Panel A", 1, 9) |>
  pack_rows("Panel B", 10, 18) |>
  pack_rows("Panel C", 19, 27)

rm(list=c("authors_main", "dat", "out1", "out2", "out3", "out4", "out5", "res", "panel",
  "avg_pre", "avg_pre1", "avg_pre2", "row_switch", "col_switch", "row_n", "col_n", "i",
  "row_names", "outcome_keys", "w"))


#############################
# Figure 4
#############################
# here we only provide the estimates from the institutional analysis for data privacy reasons
# Figure 4: plot the pre-computed institution-level estimates with their
# confidence bounds (conf_low / conf_high columns of the CSV).
res <- read_csv(file.path(datadir, "figure4_estimates.csv"))
res <- res |> mutate(Institutions = as.factor(Institutions), inv=as.factor(inv))
labels <- res$labels

p9 <- ggplot(res, aes(x=Institutions, y=coef, fill=Institutions)) +
  # `linewidth` replaces `size` for line geoms (deprecated since ggplot2 3.4.0).
  geom_errorbar(aes(ymin = conf_low, ymax = conf_high, color=inv), width = 0.3, linewidth=0.5) +
  geom_point(aes(x=Institutions, y=coef, color=inv), size=2) +
  # Two-level `inv` factor: first level black, second level red; no legend.
  scale_color_manual("inv", values=c("black", "red"), guide = 'none') +
  theme_bw() +
  geom_hline(yintercept = 0) +
  labs(x='Institutions', y='Estimates and 95% Conf. Int.') +
  # The fill scale is all black; it exists to print `labels` as a legend.
  scale_fill_manual(values = rep("black", length(labels)), labels = labels) +
  theme(legend.direction ="vertical",legend.position = "bottom") +
  guides(fill=guide_legend(ncol=5)) +
  theme(legend.text=element_text(size=10))

ggsave(file.path(outputdir, 'inst.png'), p9, height=12, width=16)

rm(list=c("p9", "res", "labels"))


#############################
# Table 3
#############################

# Table 3 data preparation: derive the complementary citation outcomes.
authors_asian <- authors |>
  mutate(
    # Each unnamed expression subtracts one block of seven yearly columns from
    # another; the resulting columns take their names from the first across()
    # call, whose `.names` glue string expands to one name per year.
    # non-China-funded citations = total citations - China-funded citations
    across(num_range('cite_total_', 2015:2021), .names = "cite_non_cn_funded_{2015:2021}") - 
      across(num_range('cite_cn_funded_', 2015:2021)),
    # non-NIH citations = total citations - NIH citations
    across(num_range('cite_total_', 2015:2021), .names = "cite_non_nih_{2015:2021}") - 
      across(num_range('cite_nih_', 2015:2021))
  )

# Outcome column prefixes, one per column of the Table 3 result matrix.
outcome_keys <- c('cite_total_', 'cite_nih_', 'cite_non_nih_', 'cite_cn_funded_', 'cite_non_cn_funded_')

# Table 3 estimation: one regression per outcome, one column of `res` each.
# 11 rows = three (estimate, s.e.) pairs plus R2, N and three fixed-effect flags.
res <- matrix(NA, 11, 5)
for (i in seq_along(outcome_keys)){
  dat <- authors_asian |>
    select(author_researcher_id, treat, pub_total_1014, cite_total_1014, pub_nih_1014, Asian,
           num_range(outcome_keys[i], 2015:2021))

  panel <- dat |>
    pivot_longer(
      cols = num_range(outcome_keys[i], 2015:2021),
      names_to = "year",
      names_prefix = outcome_keys[i],
      values_to = "Y_raw") |>
    mutate(
      # Assumes the long data holds seven consecutive rows (2015..2021) per
      # author, so recycling 2015:2021 rebuilds the calendar year.
      period = rep(2015:2021, n()/7),
      Y = log(Y_raw+1),
      post_treatment = as.numeric(period >= 2019),
      treatment = post_treatment * treat,
      author_id = as.character(author_researcher_id)
    )

  # Treatment and post-period indicators each interacted with Asian, plus
  # year-specific slopes on three baseline covariates and author/year effects.
  out <- feols(Y~treatment*Asian + post_treatment*Asian|
                 period+period[pub_total_1014]+period[cite_total_1014]+period[pub_nih_1014]+
                 author_id, data=panel, cluster=~author_id)

  # Coefficients are placed by position: the first goes to rows 3/4, the
  # second to rows 1/2 and the third to rows 5/6.
  # NOTE(review): this presumes exactly three coefficients survive, in the
  # order treatment, treatment x Asian, Asian x post -- verify against
  # names(out$coefficients).
  res[c(3, 1, 5), i] <- round_num(unname(out$coefficients), 3)
  res[c(4, 2, 6), i] <- paste0('(', round_num(out$se, 3), ')')
  res[7:11, i] <- c(round_num(r2(out, type='ar2'), 3), out$nobs, 'Y', 'Y', 'Y')
}

# Table 3 output: label rows and columns, render the LaTeX table, clean up.
row_names <- c('Ties to China × Post × Asian', '', 'Ties to China × Post', '', 'Post × Asian', '',
               'R2', 'No. of obs.', 'Scholar FE', 'Year FE', 'Baseline Covariates*Year FE')

res <- cbind(row_names, res)
colnames(res) <- c('', 'All', 'NIH-Funded', 'Non NIH-Funded', 'China-Funded', 'Non China-Funded')

# TRUE, not T: T is an ordinary variable that can be reassigned.
kbl(res, booktabs = TRUE, format='latex') |>
  kable_styling(latex_options = c("scale_down"), position = "center") |>
  add_header_above(c("", "Citations by Nature of Publication" = 5))

rm(list=c("authors_asian", "outcome_keys", "res", "dat", "panel", "out", "i", "row_names"))


#############################
# Figure 5
#############################
# here we only provide the estimates from the field level analysis for data privacy reasons
# Figure 5, left panel: field-level estimates against the NIH funding share.
# `dat` is kept after this figure because Figure 6 reuses it.
dat <- read_csv(file.path(datadir, "figure5_estimates.csv"))

# plot, very fragile, do not tamper
# Centre (xc1, yc1) and radii (rx1, rx2) of the open-circle glyph drawn in
# the legend strip; the radii differ because the strip's axes are not square.
xc1 <- 0.01
yc1 <- 0.04/2
rx1 <- 0.006
rx2 <- 0.007
# Linear fit of the estimates on pct_nih, weighted by num_total; its slope and
# standard error are printed in the legend strip.
nih_lm <- summary(lm(coef~pct_nih, data=dat, weights=num_total))

# Hand-built legend strip. `linewidth` replaces `size` for the path and the
# border rectangle (`size` for lines is deprecated since ggplot2 3.4.0); the
# `size` of the text annotations is a font size and stays as it is.
g1 <- ggplotGrob(
  ggplot() + xlim(0, 0.6) + ylim(0, 0.04) + theme_void() +
    annotate("text", x = 0.15, y = 0.04/2, label = "Treatment Effect on Citations", size =5, colour="black") +
    annotate("path", x=xc1+rx1*cos(seq(0,2*pi,length.out=100)), y=yc1+rx2*sin(seq(0,2*pi,length.out=100)), linewidth =1.1) +
    geom_segment(aes(x = 0.3, y = 0.04/2, xend = 0.34, yend = 0.04/2), color="black", linewidth =1.2) +
    annotate("text", x = 0.47, y = 0.04/2,
             label = paste0("fitted: coef = ", round_num(nih_lm$coefficients['pct_nih','Estimate'], 2),
                            ", s.e. = ", round_num(nih_lm$coefficients['pct_nih','Std. Error'], 2)),
             size =5, colour="black") +
    theme(plot.background = element_rect(colour = "black", linewidth = 1.5))
)

# Bubble plot: circle area scales with num_total, and the fitted line uses the
# same weights as nih_lm.
p10 <- ggplot(data=dat, aes(x=pct_nih, y=coef, size = num_total, label=name1)) +
  geom_point(shape=1, stroke = 1.1) +
  scale_size(range = c(1, 40)) +
  geom_smooth(method='lm', formula= y~x, mapping = aes(weight = num_total), se=FALSE, color='black') + theme_bw() +
  theme(legend.position = "none") +
  geom_text_repel(size=4, angle=10) +
  ylim(-0.25, 0.25) + labs(x='Percentage of NIH Funding', y='Estimates') +
  theme(axis.text = element_text(size=20),
        axis.title = element_text(size=20),
        axis.line = element_line(colour = 'black', linewidth = 1)) + xlim(0, 0.62)

# Figure 5, right panel: field-level estimates against the share of U.S.-CN
# collaborations.
# Centre (xc2, yc2) and radii (rx3, rx4) of the open-circle glyph drawn in
# the legend strip; the radii differ because the strip's axes are not square.
xc2 <- 0.003
yc2 <- 0.04/2
rx3 <- 0.0015
rx4 <- 0.007
# Linear fit of the estimates on pct_us_cn_colab, weighted by num_total; its
# slope and standard error are printed in the legend strip.
cn_lm <- summary(lm(coef~pct_us_cn_colab, data=dat, weights=num_total))

# Hand-built legend strip. `linewidth` replaces `size` for the path and the
# border rectangle (`size` for lines is deprecated since ggplot2 3.4.0); the
# `size` of the text annotations is a font size and stays as it is.
g2 <- ggplotGrob(
  ggplot() + xlim(0, 0.15) + ylim(0, 0.04) + theme_void() +
    annotate("text", x = 0.038, y = 0.04/2, label = "Treatment Effect on Citations", size =5, colour="black") +
    annotate("path", x=xc2+rx3*cos(seq(0,2*pi,length.out=100)), y=yc2+rx4*sin(seq(0,2*pi,length.out=100)), linewidth =1.1) +
    geom_segment(aes(x = 0.076, y = 0.04/2, xend = 0.087, yend = 0.04/2), color="black", linewidth =1.2) +
    annotate("text", x = 0.12, y = 0.04/2,
             label = paste0("fitted: coef = ", round_num(cn_lm$coefficients['pct_us_cn_colab','Estimate'], 2),
                            ", s.e. = ", round_num(cn_lm$coefficients['pct_us_cn_colab','Std. Error'], 2)),
             size =5, colour="black") +
    theme(plot.background = element_rect(colour = "black", linewidth = 1.5))
)

# Bubble plot: circle area scales with num_total, and the fitted line uses the
# same weights as cn_lm.
p11 <- ggplot(data=dat, aes(x=pct_us_cn_colab, y=coef, size = num_total, label=name1)) +
  geom_point(shape=1, stroke = 1.1) +
  scale_size(range = c(1, 40)) +
  geom_smooth(method='lm', formula= y~x, mapping = aes(weight = num_total), se=FALSE, color='black') + theme_bw() +
  theme(legend.position = "none") +
  geom_text_repel(size=4, angle=10) +
  ylim(-0.25, 0.25) + labs(x='Percentage of U.S.-CN Collaborations', y='Estimates') +
  theme(axis.text = element_text(size=20),
        axis.title = element_text(size=20),
        axis.line = element_line(colour = 'black', linewidth = 1)) + xlim(0, 0.162)

# Attach each legend strip to its bubble plot. The y range passed to
# annotation_custom() lies below the plotted y limits, and the enlarged bottom
# margin (70pt) reserves room for the strip.
c1 <- p10 +
  annotation_custom(g1, xmin = 0, xmax = 0.6, ymin = -0.335, ymax = -0.364) +
  theme(plot.margin = unit(c(5.5, 5.5, 70, 5.5), units = "pt"))

c2 <- p11 +
  annotation_custom(g2, xmin = 0, xmax = 0.155, ymin = -0.335, ymax = -0.365) +
  theme(plot.margin = unit(c(5.5, 5.5, 70, 5.5), units = "pt"))

# Side-by-side layout written to disk; `dat` is deliberately not removed.
bubble <- arrangeGrob(grobs = list(c1, c2), nrow = 1)
ggsave(
  filename = file.path(outputdir, 'bubble_combine.png'),
  plot = bubble,
  width = 19,
  height = 10,
  dpi = 300
)

rm(c1, c2, p10, p11, bubble, g1, g2,
   xc1, yc1, rx1, rx2, xc2, yc2, rx3, rx4,
   nih_lm, cn_lm)


#############################
# Figure 6
#############################
# Country-by-field data arrives in two CSV files that are stacked row-wise.
country_one <- file.path(datadir, '50countries1.csv') |>
  read_csv()
country_two <- file.path(datadir, '50countries2.csv') |>
  read_csv()

# Prefix the field id with 'F' before joining on `category_for_id`, then
# attach the field-level estimates held in `dat`.
country_fields <- country_one |>
  rbind(country_two) |>
  mutate(category_for_id = paste0('F', category_for_id)) |>
  left_join(dat, by = 'category_for_id')
  
# China
# Field-by-field difference-in-differences: China (treated) against every
# other country except the U.S., with 2019 onwards as the post period.
country_fields_cn <- country_fields |> filter(country!='US')

panel <- country_fields_cn |>
  mutate(
    Y = log(num+1),
    post_treatment = as.numeric(year >= 2019),
    treat = ifelse(country=='CN', 1, 0),
    treatment = post_treatment * treat)

# Iterate over the fields of `dat` itself. The loop indexes dat$category_for_id
# and the result is stored as a column of `dat`, so the length must come from
# `dat`, not from the number of distinct fields in the country data.
dat_fields <- dat$category_for_id
n <- length(dat_fields)
out <- rep(NA, n)

for (i in seq_len(n)){
  p <- panel[panel$category_for_id==dat_fields[i],]
  # Single regressor with country and year fixed effects.
  logged <- feols(Y ~ treatment|country+year, data=p)
  out[i] <- unname(logged$coefficients)
}
dat$did_cn <- out

# U.S.
# Field-by-field difference-in-differences: the U.S. (treated) against every
# other country except China, with 2019 onwards as the post period.
country_fields_us <-  country_fields |> filter(country!='CN')

panel <- country_fields_us |>
  mutate(
    Y = log(num+1),
    post_treatment = as.numeric(year >= 2019),
    treat = ifelse(country=='US', 1, 0),
    treatment = post_treatment * treat)

# One estimate per field of `dat`; sizing from dat_fields keeps the result
# aligned with the rows of `dat` that it is stored in.
out <- rep(NA, length(dat_fields))

for (i in seq_along(dat_fields)){
  p <- panel[panel$category_for_id==dat_fields[i],]
  # Single regressor with country and year fixed effects.
  logged <- feols(Y ~ treatment| country+year, data=p)
  out[i] <- unname(logged$coefficients)
}
dat$did_us <- out

# plot, very fragile, do not tamper
# Figure 6, left panel: China's field-level publication DiD estimates against
# the in-sample estimates, with bubbles sized by num_total.
# `linewidth` replaces `size` on every line element below (`size` for lines
# is deprecated since ggplot2 3.4.0); text sizes are untouched.
p12 <- ggplot(data=dat, aes(x=coef, y=did_cn, size = num_total, label=name1)) +
  geom_point(shape=1, stroke = 1.1) +
  scale_size(range = c(1, 40)) +
  geom_smooth(method='lm', formula= y~x, mapping = aes(weight = num_total), se=FALSE, color='black') +
  geom_text_repel(size=3, angle=10) +
  theme_bw() + labs(x='In-sample DiD Estimates', y='Total Publication DiD Estimates') +
  ggtitle('China') +
  theme(axis.text = element_text(size=20),
        axis.title = element_text(size=20),
        axis.line = element_line(colour = 'black', linewidth = 1),
        plot.title = element_text(size = 28)) +
  theme(legend.position = "none") +
  ylim(0.1, 0.95) + xlim(-0.25, 0.14)

# Centre and radius of the open-circle glyph in the legend strip.
xc1 <- -0.145
yc1 <- 0.135
r1 <- 0.002
# Linear fit of did_cn on coef, weighted by num_total; its slope and standard
# error are printed in the legend strip.
field_cn <- summary(lm(did_cn~coef, data=dat, weights=num_total))

g1 <- ggplotGrob(ggplot() + xlim(-0.15, 0.16) + ylim(0.13, 0.14) + theme_void() +
                   annotate("text", x = -0.075, y = 0.135, label = "Total Publication DiD Estimates", size =5, colour="black") +
                   annotate("path", x=xc1+r1*cos(seq(0,2*pi,length.out=100)), y=yc1+r1*sin(seq(0,2*pi,length.out=100)), linewidth =1.1) +
                   geom_segment(aes(x = 0.018, y = 0.135, xend = 0.038, yend = 0.135), color="black", linewidth =1.2) +
                   annotate("text", x = 0.10, y = 0.135,
                            label = paste0("fitted: coef = ", round_num(field_cn$coefficients['coef','Estimate'], 2),
                                           ", s.e. = ", round_num(field_cn$coefficients['coef','Std. Error'], 2)),
                            size =5, colour="black") +
                   theme(plot.background = element_rect(colour = "black", linewidth = 1.5)))


# Place the strip below the plotted y range; the 70pt bottom margin reserves
# room for it.
c1 <- p12 + annotation_custom(
  grob = g1, xmin = -0.25, xmax = 0.13, ymin = -0.07, ymax = -0.04) +
  theme(plot.margin = unit(c(5.5, 5.5, 70, 5.5), "pt"))

# Figure 6, right panel: U.S. field-level publication DiD estimates against
# the in-sample estimates, with bubbles sized by num_total.
# `linewidth` replaces `size` on every line element below (`size` for lines
# is deprecated since ggplot2 3.4.0); text sizes are untouched.
p13 <- ggplot(data=dat, aes(x=coef, y=did_us, size = num_total, label=name1)) +
  geom_point(shape=1, stroke = 1.1) +
  scale_size(range = c(1, 40)) +
  geom_smooth(method='lm', formula= y~x, mapping = aes(weight = num_total), se=FALSE, color='black') +
  geom_text_repel(size=4, angle=10) +
  theme_bw() + labs(x='In-sample DiD Estimates', y='Total Publication DiD Estimates') +
  ggtitle('U.S.') +
  theme(axis.text = element_text(size=20),
        axis.title = element_text(size=20),
        axis.line = element_line(colour = 'black', linewidth = 1),
        plot.title = element_text(size = 28)) +
  theme(legend.position = "none") +
  ylim(-0.52, -0.07) + xlim(-0.25, 0.14)

# Centre and radius of the open-circle glyph in the legend strip.
# NOTE(review): the variable `r2` shares its name with the r2() function used
# for the regression tables; calls to r2() still resolve to the function, but
# the name is easy to misread.
xc2 <- -0.145
yc2 <- -0.325
r2 <- 0.002
# Linear fit of did_us on coef, weighted by num_total; its slope and standard
# error are printed in the legend strip.
field_us <- summary(lm(did_us~coef, data=dat, weights=num_total))

g2 <- ggplotGrob(ggplot() + xlim(-0.15, 0.16) + ylim(-0.33, -0.32) + theme_void() +
                   annotate("text", x = -0.07, y = -0.325, label = "Total Publication DiD Estimates", size =5, colour="black") +
                   annotate("path", x=xc2+r2*cos(seq(0,2*pi,length.out=100)), y=yc2+r2*sin(seq(0,2*pi,length.out=100)), linewidth =1.1) +
                   geom_segment(aes(x = 0.015, y = -0.325, xend = 0.035, yend = -0.325), color="black", linewidth =1.2) +
                   annotate("text", x = 0.10, y = -0.325,
                            label = paste0("fitted: coef = ", round_num(field_us$coefficients['coef','Estimate'], 2),
                                           ", s.e. = ", round_num(field_us$coefficients['coef','Std. Error'], 2)),
                            size =5, colour="black") +
                   theme(plot.background = element_rect(colour = "black", linewidth = 1.5)))


# Place the strip below the plotted y range; the 70pt bottom margin reserves
# room for it.
c2 <- p13 + annotation_custom(
  grob = g2, xmin = -0.25, xmax = 0.13, ymin = -0.61, ymax = -0.5935) +
  theme(plot.margin = unit(c(5.5, 5.5, 70, 5.5), "pt"))

# Lay the China and U.S. panels out side by side and write the PNG.
p14 <- arrangeGrob(grobs = list(c1, c2), nrow = 1)
ggsave(
  filename = file.path(outputdir, 'did_combine.png'),
  plot = p14,
  width = 18,
  height = 10,
  dpi = 300
)

